{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "# default_exp tutorial\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "import shutil\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"\n",
    "\n",
    "# clean working dir\n",
    "if os.path.exists('./models'):\n",
    "    shutil.rmtree('./models')\n",
    "\n",
    "if os.path.exists('./tmp'):\n",
    "    shutil.rmtree('./tmp')\n",
    "\n",
    "import warnings\n",
    "warnings.simplefilter('ignore')\n",
    "\n",
    "import tensorflow as tf\n",
    "tf.autograph.set_verbosity(0)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Tutorial\n",
    "\n",
    "In this tutorial, we'll start with the most basic example to get you up and running as quickly and easily as possible. Then we'll move on to more complicated examples, which should give you some insight into what happens behind the scenes."
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Minimal Example\n",
    "\n",
    "In this example, we'll create toy problems and use them for training, evaluation, and prediction. But first, we need to define what a problem means here. Essentially, a problem has **a name (string), a problem type (string), and a preprocessing function (callable)**. The following problem types are pre-defined:\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "# hide\n",
    "from m3tl.params import Params\n",
    "import pprint\n",
    "params = Params()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "source": [
    "for problem_type in params.list_available_problem_types():\n",
    "    print('`{problem_type}`: {desc}'.format(\n",
    "        desc=params.problem_type_desc[problem_type], problem_type=problem_type))"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "`cls`: Classification\n",
      "`multi_cls`: Multi-Label Classification\n",
      "`seq_tag`: Sequence Labeling\n",
      "`masklm`: Masked Language Model\n",
      "`pretrain`: NSP+MLM(Deprecated)\n",
      "`regression`: Regression\n",
      "`vector_fit`: Vector Fitting\n",
      "`premask_mlm`: Pre-masked Masked Language Model\n",
      "`contrastive_learning`: Contrastive Learning\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "Normally, you would want to use this library for multi-task learning. Two chaining operators can be used to combine problems:\n",
    "\n",
    "- `&`. If two problems have the same inputs, they can be chained using `&`. Problems chained by `&` will be trained at the same time.\n",
    "- `|`. If two problems don't have the same inputs, they need to be chained using `|`. Problems chained by `|` are sampled during training, so each training instance comes from one of them.\n",
    "\n",
    "> Note: chaining problems with `&` works better when preprocessing is done with pyspark and an `inputs_record_id` key is provided. For more information, please refer to [Write More Flexible Preprocessing Function](#Write-More-Flexible-Preprocessing-Function).\n",
    "\n",
    "If your problem does not fall into one of the pre-defined problem types, you can implement your own and register it with `params`. We will cover this topic later. Let's start with a simple example: a classification problem and a sequence labeling problem."
   ],
   "metadata": {}
  },
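  {
   "cell_type": "markdown",
   "source": [
    "To make the two operators concrete, here is a minimal sketch of how a chained problem string can be read, assuming `&` binds more tightly than `|`: the string is split on `|` into groups that are sampled during training, and each group is split on `&` into problems trained together on shared inputs. `parse_problem_string` and `another_problem` are hypothetical names used only for illustration; this is not m3tl's own parser, and it does not model sampling weights."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# Hypothetical helper illustrating the `&` / `|` chaining semantics.\n",
    "# This is NOT m3tl's internal parser; it assumes `&` binds tighter than `|`.\n",
    "from typing import List\n",
    "\n",
    "\n",
    "def parse_problem_string(problem: str) -> List[List[str]]:\n",
    "    # Split on `|` into groups that are sampled against each other,\n",
    "    # then split each group on `&` into problems trained together on shared inputs.\n",
    "    groups = []\n",
    "    for group in problem.split('|'):\n",
    "        names = [name.strip() for name in group.split('&') if name.strip()]\n",
    "        if names:\n",
    "            groups.append(names)\n",
    "    return groups\n",
    "\n",
    "\n",
    "print(parse_problem_string('toy_cls&toy_seq_tag'))\n",
    "print(parse_problem_string('toy_cls&toy_seq_tag|another_problem'))"
   ],
   "outputs": [],
   "metadata": {}
  },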
  {
   "cell_type": "code",
   "execution_count": 4,
   "source": [
    "# define toy problems name and problem type\n",
    "problem_type_dict = {'toy_cls': 'cls', 'toy_seq_tag': 'seq_tag'}"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "Next, we need to implement a preprocessing function for each problem. The preprocessing function is a callable that:\n",
    "\n",
    "- has the same name as the problem\n",
    "- has a fixed input signature, `(params: Params, mode: str)`\n",
    "- returns (or yields) inputs and targets\n",
    "- is decorated with `m3tl.preproc_decorator.preprocessing_fn`"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "source": [
    "# define a simple preprocessing function\n",
    "import m3tl\n",
    "from m3tl.preproc_decorator import preprocessing_fn\n",
    "from m3tl.params import Params\n",
    "from m3tl.special_tokens import TRAIN\n",
    "@preprocessing_fn\n",
    "def toy_cls(params: Params, mode: str):\n",
    "    \"Simple example to demonstrate a single-modal, tuple-of-lists return\"\n",
    "    if mode == TRAIN:\n",
    "        toy_input = ['this is a test' for _ in range(10)]\n",
    "        toy_target = ['a' if i <=5 else 'b' for i in range(10)]\n",
    "    else:\n",
    "        toy_input = ['this is a test' for _ in range(10)]\n",
    "        toy_target = ['a' if i <=5 else 'b' for i in range(10)]\n",
    "    return toy_input, toy_target\n",
    "\n",
    "@preprocessing_fn\n",
    "def toy_seq_tag(params: Params, mode: str):\n",
    "    \"Simple example to demonstrate a single-modal, tuple-of-lists return\"\n",
    "    if mode == TRAIN:\n",
    "        toy_input = ['this is a test'.split(' ') for _ in range(10)]\n",
    "        toy_target = [['a', 'b', 'c', 'd'] for _ in range(10)]\n",
    "    else:\n",
    "        toy_input = ['this is a test'.split(' ') for _ in range(10)]\n",
    "        toy_target = [['a', 'b', 'c', 'd'] for _ in range(10)]\n",
    "    return toy_input, toy_target\n",
    "\n",
    "processing_fn_dict = {'toy_cls': toy_cls, 'toy_seq_tag': toy_seq_tag}"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "Now we're good to go! Since these two toy problems share the same inputs, we can chain them with `&`."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "source": [
    "# collapse_output\n",
    "from m3tl.run_bert_multitask import train_bert_multitask, eval_bert_multitask, predict_bert_multitask\n",
    "problem = 'toy_cls&toy_seq_tag'\n",
    "# train\n",
    "model = train_bert_multitask(\n",
    "    problem=problem,\n",
    "    num_epochs=1,\n",
    "    problem_type_dict=problem_type_dict,\n",
    "    processing_fn_dict=processing_fn_dict,\n",
    "    continue_training=False\n",
    ")\n"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "2021-06-24 16:07:25.768 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem toy_cls, problem type: cls\n",
      "2021-06-24 16:07:25.769 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem toy_seq_tag, problem type: seq_tag\n",
      "2021-06-24 16:07:25.770 | WARNING  | m3tl.base_params:prepare_dir:363 - bert_config not exists. will load model from huggingface checkpoint.\n",
      "2021-06-24 16:07:27.851 | WARNING  | m3tl.read_write_tfrecord:chain_processed_data:258 - Chaining problems with & may consume a lot of memory if data is not pyspark RDD.\n",
      "2021-06-24 16:07:27.853 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: this is a test\n",
      "2021-06-24 16:07:27.854 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text\n",
      "2021-06-24 16:07:27.854 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 8554, 8310, 143, 10060, 102]\n",
      "2021-06-24 16:07:27.855 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]\n",
      "2021-06-24 16:07:27.855 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]\n",
      "2021-06-24 16:07:27.856 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - toy_cls_label_ids: 0\n",
      "2021-06-24 16:07:27.862 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: ['this', 'is', 'a', 'test']\n",
      "2021-06-24 16:07:27.862 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text\n",
      "2021-06-24 16:07:27.863 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 8554, 8310, 143, 10060, 102]\n",
      "2021-06-24 16:07:27.863 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]\n",
      "2021-06-24 16:07:27.863 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]\n",
      "2021-06-24 16:07:27.864 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - toy_seq_tag_label_ids: [0, 1, 2, 3, 4, 0]\n",
      "2021-06-24 16:07:27.867 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:135 - Writing tmp/toy_cls_toy_seq_tag/train_00000.tfrecord\n",
      "2021-06-24 16:07:27.898 | WARNING  | m3tl.read_write_tfrecord:chain_processed_data:258 - Chaining problems with & may consume a lot of memory if data is not pyspark RDD.\n",
      "2021-06-24 16:07:27.899 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: this is a test\n",
      "2021-06-24 16:07:27.900 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text\n",
      "2021-06-24 16:07:27.900 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 8554, 8310, 143, 10060, 102]\n",
      "2021-06-24 16:07:27.901 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]\n",
      "2021-06-24 16:07:27.901 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]\n",
      "2021-06-24 16:07:27.901 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - toy_cls_label_ids: 0\n",
      "2021-06-24 16:07:27.905 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: ['this', 'is', 'a', 'test']\n",
      "2021-06-24 16:07:27.906 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text\n",
      "2021-06-24 16:07:27.906 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 8554, 8310, 143, 10060, 102]\n",
      "2021-06-24 16:07:27.907 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]\n",
      "2021-06-24 16:07:27.907 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]\n",
      "2021-06-24 16:07:27.907 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - toy_seq_tag_label_ids: [0, 1, 2, 3, 4, 0]\n",
      "2021-06-24 16:07:27.910 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:135 - Writing tmp/toy_cls_toy_seq_tag/eval_00000.tfrecord\n",
      "2021-06-24 16:07:28.601 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: \n",
      "2021-06-24 16:07:28.602 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {\n",
      "    \"toy_cls_toy_seq_tag\": 1.0\n",
      "}\n",
      "2021-06-24 16:07:28.750 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: \n",
      "2021-06-24 16:07:28.751 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {\n",
      "    \"toy_cls_toy_seq_tag\": 1.0\n",
      "}\n",
      "2021-06-24 16:07:28.943 | CRITICAL | m3tl.base_params:update_train_steps:456 - Updating train_steps to 1\n",
      "2021-06-24 16:07:29.062 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: \n",
      "2021-06-24 16:07:29.063 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {\n",
      "    \"toy_cls_toy_seq_tag\": 1.0\n",
      "}\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.\n",
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Some layers from the model checkpoint at bert-base-chinese were not used when initializing TFBertModel: ['mlm___cls', 'nsp___cls']\n",
      "- This IS expected if you are initializing TFBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing TFBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "All the layers of TFBertModel were initialized from the model checkpoint at bert-base-chinese.\n",
      "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.\n",
      "2021-06-24 16:07:33.091 | CRITICAL | m3tl.embedding_layer.base:__init__:58 - Modal Type id mapping: \n",
      " {\n",
      "    \"text\": 0\n",
      "}\n",
      "2021-06-24 16:07:33.278 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer\n",
      "2021-06-24 16:07:33.289 | CRITICAL | m3tl.model_fn:compile:271 - Initial lr: 0.0\n",
      "2021-06-24 16:07:33.290 | CRITICAL | m3tl.model_fn:compile:272 - Train steps: 1\n",
      "2021-06-24 16:07:33.290 | CRITICAL | m3tl.model_fn:compile:273 - Warmup steps: 0\n",
      "2021-06-24 16:07:33.496 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train\n",
      "The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
      "The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n",
      "2021-06-24 16:07:50.903 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "1/1 [==============================] - ETA: 0s - mean_acc: 0.9431 - toy_cls_acc: 0.6000 - toy_seq_tag_acc: 0.0857 - BertMultiTaskTop/toy_cls/losses/0: 0.7577 - BertMultiTaskTop/toy_seq_tag/losses/0: 2.3291"
     ]
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "2021-06-24 16:08:04.196 | INFO     | m3tl.utils:set_phase:478 - Setting phase to eval\n",
      "The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
      "The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 1000 batches). You may need to use the repeat() function when building your dataset.\n",
      "1/1 [==============================] - 39s 39s/step - mean_acc: 0.9431 - toy_cls_acc: 0.6000 - toy_seq_tag_acc: 0.0857 - BertMultiTaskTop/toy_cls/losses/0: 0.7577 - BertMultiTaskTop/toy_seq_tag/losses/0: 2.3291 - val_loss: 3.1377 - val_mean_acc: 0.3000 - val_toy_cls_acc: 0.6000 - val_toy_seq_tag_acc: 0.0000e+00 - val_BertMultiTaskTop/toy_cls/losses/0: 0.6794 - val_BertMultiTaskTop/toy_seq_tag/losses/0: 2.4584\n",
      "Model: \"BertMultiTask\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "BertMultiTaskBody (BertMulti multiple                  102268416 \n",
      "_________________________________________________________________\n",
      "basic_mtl (BasicMTL)         multiple                  0         \n",
      "_________________________________________________________________\n",
      "BertMultiTaskTop (BertMultiT multiple                  5387      \n",
      "_________________________________________________________________\n",
      "sum_loss_combination (SumLos multiple                  0         \n",
      "=================================================================\n",
      "Total params: 102,273,805\n",
      "Trainable params: 102,273,799\n",
      "Non-trainable params: 6\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "For evaluation, we need to pass either `model_dir` or `model` to the function. Note that the `Unresolved object in checkpoint` warnings raised by TensorFlow are expected: the optimizer's state is not initialized during evaluation and prediction, so those checkpoint values go unused."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "source": [
    "# collapse_output\n",
    "# eval\n",
    "eval_dict = eval_bert_multitask(problem=problem,\n",
    "                    problem_type_dict=problem_type_dict, processing_fn_dict=processing_fn_dict,\n",
    "                    model_dir=model.params.ckpt_dir)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "2021-06-24 16:08:13.640 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem toy_cls, problem type: cls\n",
      "2021-06-24 16:08:13.641 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem toy_seq_tag, problem type: seq_tag\n",
      "2021-06-24 16:08:13.782 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: \n",
      "2021-06-24 16:08:13.782 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {\n",
      "    \"toy_cls_toy_seq_tag\": 1.0\n",
      "}\n",
      "2021-06-24 16:08:14.095 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: \n",
      "2021-06-24 16:08:14.096 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {\n",
      "    \"toy_cls_toy_seq_tag\": 1.0\n",
      "}\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.\n",
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "2021-06-24 16:08:15.113 | CRITICAL | m3tl.embedding_layer.base:__init__:58 - Modal Type id mapping: \n",
      " {\n",
      "    \"text\": 0\n",
      "}\n",
      "2021-06-24 16:08:15.189 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer\n",
      "The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
      "The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n",
      "2021-06-24 16:08:17.691 | CRITICAL | m3tl.model_fn:compile:271 - Initial lr: 0.0\n",
      "2021-06-24 16:08:17.692 | CRITICAL | m3tl.model_fn:compile:272 - Train steps: 1\n",
      "2021-06-24 16:08:17.693 | CRITICAL | m3tl.model_fn:compile:273 - Warmup steps: 0\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter\n",
      "WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1\n",
      "WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2\n",
      "WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay\n",
      "WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "2021-06-24 16:08:18.193 | INFO     | m3tl.utils:set_phase:478 - Setting phase to eval\n",
      "The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
      "The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "2/2 [==============================] - 8s 8ms/step - loss: 3.1377 - mean_acc: 0.3000 - toy_cls_acc: 0.6000 - toy_seq_tag_acc: 0.0000e+00 - BertMultiTaskTop/toy_cls/losses/0: 0.6794 - BertMultiTaskTop/toy_seq_tag/losses/0: 2.4584\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "source": [
    "print(eval_dict)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "{'loss': 3.1377058029174805, 'mean_acc': 0.30000001192092896, 'toy_cls_acc': 0.6000000238418579, 'toy_seq_tag_acc': 0.0, 'BertMultiTaskTop/toy_cls/losses/0': 0.6793524026870728, 'BertMultiTaskTop/toy_seq_tag/losses/0': 2.458353281021118}\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "source": [
    "# collapse_output\n",
    "# predict\n",
    "fake_inputs = ['this is a test'.split(' ') for _ in range(10)]\n",
    "pred, model = predict_bert_multitask(\n",
    "    problem=problem,\n",
    "    inputs=fake_inputs, model_dir=model.params.ckpt_dir,\n",
    "    problem_type_dict=problem_type_dict,\n",
    "    processing_fn_dict=processing_fn_dict, return_model=True)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "2021-06-24 16:08:26.800 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer\n",
      "2021-06-24 16:08:26.801 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem toy_cls, problem type: cls\n",
      "2021-06-24 16:08:26.802 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem toy_seq_tag, problem type: seq_tag\n",
      "2021-06-24 16:08:26.822 | INFO     | m3tl.run_bert_multitask:predict_bert_multitask:464 - Checkpoint dir: models/toy_cls_toy_seq_tag_ckpt\n",
      "2021-06-24 16:08:29.839 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: ['this', 'is', 'a', 'test']\n",
      "2021-06-24 16:08:29.840 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text\n",
      "2021-06-24 16:08:29.841 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 8554, 8310, 143, 10060, 102]\n",
      "2021-06-24 16:08:29.841 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]\n",
      "2021-06-24 16:08:29.841 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]\n",
      "2021-06-24 16:08:29.891 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: ['this', 'is', 'a', 'test']\n",
      "2021-06-24 16:08:29.892 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text\n",
      "2021-06-24 16:08:29.892 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 8554, 8310, 143, 10060, 102]\n",
      "2021-06-24 16:08:29.893 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]\n",
      "2021-06-24 16:08:29.894 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]\n",
      "2021-06-24 16:08:29.914 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: ['this', 'is', 'a', 'test']\n",
      "2021-06-24 16:08:29.915 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text\n",
      "2021-06-24 16:08:29.915 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 8554, 8310, 143, 10060, 102]\n",
      "2021-06-24 16:08:29.916 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]\n",
      "2021-06-24 16:08:29.916 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]\n",
      "2021-06-24 16:08:30.835 | CRITICAL | m3tl.embedding_layer.base:__init__:58 - Modal Type id mapping: \n",
      " {\n",
      "    \"text\": 0\n",
      "}\n",
      "2021-06-24 16:08:30.907 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer\n",
      "The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
      "The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n",
      "2021-06-24 16:08:33.300 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer\n",
      "The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
      "The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n",
      "2021-06-24 16:08:39.912 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: ['this', 'is', 'a', 'test']\n",
      "2021-06-24 16:08:39.913 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text\n",
      "2021-06-24 16:08:39.913 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 8554, 8310, 143, 10060, 102]\n",
      "2021-06-24 16:08:39.914 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]\n",
      "2021-06-24 16:08:39.914 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "`pred` is a dictionary that maps each problem name to an array of predicted probability distributions."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "source": [
    "for problem_name, prob_array in pred.items():\n",
    "    print(f'{problem_name} - {prob_array.shape}')"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "toy_cls - (10, 2)\n",
      "toy_seq_tag - (10, 7, 5)\n"
     ]
    }
   ],
   "metadata": {}
  },
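  {
   "cell_type": "markdown",
   "source": [
    "Each array holds probability distributions with the label axis last: judging by the shapes above, `toy_cls` is (examples, classes) and `toy_seq_tag` is (examples, token positions, labels). The minimal sketch below turns such arrays into predicted label indices with `numpy.argmax`. It runs on random placeholder arrays with the same shapes (not real model output), so you can swap in `pred` directly; mapping the indices back to label strings depends on each problem's label encoder, which is not shown here."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "import numpy as np\n",
    "\n",
    "# Placeholder probability arrays with the same shapes as the `pred` arrays printed above.\n",
    "# They are random (each distribution sums to 1); they are NOT real predictions.\n",
    "rng = np.random.default_rng(0)\n",
    "fake_pred = {\n",
    "    'toy_cls': rng.dirichlet(np.ones(2), size=10),  # shape (10, 2)\n",
    "    'toy_seq_tag': rng.dirichlet(np.ones(5), size=(10, 7)),  # shape (10, 7, 5)\n",
    "}\n",
    "\n",
    "# The label axis is last, so argmax over axis=-1 picks the most likely label index\n",
    "# per example (cls) or per token position (seq_tag).\n",
    "pred_ids = {name: np.argmax(probs, axis=-1) for name, probs in fake_pred.items()}\n",
    "for name, ids in pred_ids.items():\n",
    "    print(f'{name} - {ids.shape}')"
   ],
   "outputs": [],
   "metadata": {}
  },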
  {
   "cell_type": "markdown",
   "source": [
    "## Use Different Models\n",
    "\n",
    "By default, we use BERT as the base model. Thanks to the Hugging Face `transformers` library, it's easy to switch to other state-of-the-art transformer models: set the model, tokenizer and config names (and their loading classes) on `params`, then pass `params` to the train function as an argument."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "source": [
    "# hide_output\n",
    "# change model to distilbert-base-uncased\n",
    "from m3tl.params import Params\n",
    "params = Params()\n",
    "# specify model and its loading module\n",
    "params.transformer_model_name = 'distilbert-base-uncased'\n",
    "params.transformer_model_loading = 'TFDistilBertModel'\n",
    "# specify tokenizer and its loading module\n",
    "params.transformer_tokenizer_name = 'distilbert-base-uncased'\n",
    "params.transformer_tokenizer_loading = 'DistilBertTokenizer'\n",
    "# specify config and its loading module\n",
    "params.transformer_config_name = 'distilbert-base-uncased'\n",
    "params.transformer_config_loading = 'DistilBertConfig'\n"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "Besides the \"body\" model, we can also choose the multi-task learning (MTL) model. The default is hard parameter sharing, but other MTL models are also implemented. To see what's available, run:"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "source": [
    "import json\n",
    "print(json.dumps(params.list_available_mtl_setup(), indent=4))"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "{\n",
      "    \"available_mtl_model\": [\n",
      "        \"basic\",\n",
      "        \"mmoe\"\n",
      "    ],\n",
      "    \"available_problem_sampling_strategy\": [],\n",
      "    \"available_loss_combination_strategy\": [\n",
      "        \"sum\"\n",
      "    ],\n",
      "    \"available_gradient_surgery\": []\n",
      "}\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "source": [
    "# collapse_output\n",
    "# train model with mmoe\n",
    "params.assign_mtl_model('mmoe')\n",
    "model = train_bert_multitask(\n",
    "    problem=problem,\n",
    "    num_epochs=1,\n",
    "    problem_type_dict=problem_type_dict,\n",
    "    processing_fn_dict=processing_fn_dict,\n",
    "    continue_training=False,\n",
    "    params=params # pass params\n",
    ")\n"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "2021-06-24 16:08:41.917 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem toy_cls, problem type: cls\n",
      "2021-06-24 16:08:41.918 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem toy_seq_tag, problem type: seq_tag\n",
      "2021-06-24 16:08:41.919 | WARNING  | m3tl.base_params:prepare_dir:363 - bert_config not exists. will load model from huggingface checkpoint.\n",
      "2021-06-24 16:08:44.124 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: \n",
      "2021-06-24 16:08:44.125 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {\n",
      "    \"toy_cls_toy_seq_tag\": 1.0\n",
      "}\n",
      "2021-06-24 16:08:44.279 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: \n",
      "2021-06-24 16:08:44.280 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {\n",
      "    \"toy_cls_toy_seq_tag\": 1.0\n",
      "}\n",
      "2021-06-24 16:08:44.467 | CRITICAL | m3tl.base_params:update_train_steps:456 - Updating train_steps to 1\n",
      "2021-06-24 16:08:44.589 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: \n",
      "2021-06-24 16:08:44.589 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {\n",
      "    \"toy_cls_toy_seq_tag\": 1.0\n",
      "}\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.\n",
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "Some layers from the model checkpoint at distilbert-base-uncased were not used when initializing TFDistilBertModel: ['vocab_transform', 'vocab_projector', 'vocab_layer_norm', 'activation_13']\n",
      "- This IS expected if you are initializing TFDistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing TFDistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "All the layers of TFDistilBertModel were initialized from the model checkpoint at distilbert-base-uncased.\n",
      "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.\n",
      "2021-06-24 16:08:47.617 | CRITICAL | m3tl.embedding_layer.base:__init__:58 - Modal Type id mapping: \n",
      " {\n",
      "    \"text\": 0\n",
      "}\n",
      "2021-06-24 16:08:47.695 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer\n",
      "2021-06-24 16:08:47.702 | CRITICAL | m3tl.model_fn:compile:271 - Initial lr: 0.0\n",
      "2021-06-24 16:08:47.703 | CRITICAL | m3tl.model_fn:compile:272 - Train steps: 1\n",
      "2021-06-24 16:08:47.703 | CRITICAL | m3tl.model_fn:compile:273 - Warmup steps: 0\n",
      "2021-06-24 16:08:47.874 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train\n",
      "The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
      "The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n",
      "2021-06-24 16:08:55.282 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "1/1 [==============================] - ETA: 0s - mean_acc: 0.7434 - toy_cls_acc: 0.4000 - toy_seq_tag_acc: 0.2714 - BertMultiTaskTop/toy_cls/losses/0: 0.6950 - BertMultiTaskTop/toy_seq_tag/losses/0: 1.6072"
     ]
    },
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "2021-06-24 16:09:05.034 | INFO     | m3tl.utils:set_phase:478 - Setting phase to eval\n",
      "The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).\n",
      "The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 1000 batches). You may need to use the repeat() function when building your dataset.\n",
      "1/1 [==============================] - 22s 22s/step - mean_acc: 0.7434 - toy_cls_acc: 0.4000 - toy_seq_tag_acc: 0.2714 - BertMultiTaskTop/toy_cls/losses/0: 0.6950 - BertMultiTaskTop/toy_seq_tag/losses/0: 1.6072 - val_loss: 2.3018 - val_mean_acc: 0.3429 - val_toy_cls_acc: 0.4000 - val_toy_seq_tag_acc: 0.2857 - val_BertMultiTaskTop/toy_cls/losses/0: 0.6948 - val_BertMultiTaskTop/toy_seq_tag/losses/0: 1.6070\n",
      "Model: \"BertMultiTask\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "BertMultiTaskBody (BertMulti multiple                  66363648  \n",
      "_________________________________________________________________\n",
      "m_mo_e (MMoE)                multiple                  799760    \n",
      "_________________________________________________________________\n",
      "BertMultiTaskTop (BertMultiT multiple                  907       \n",
      "_________________________________________________________________\n",
      "sum_loss_combination_3 (SumL multiple                  0         \n",
      "=================================================================\n",
      "Total params: 67,164,317\n",
      "Trainable params: 67,164,311\n",
      "Non-trainable params: 6\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "metadata": {}
  },
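  {
   "cell_type": "markdown",
   "source": [
    "For intuition, here is a minimal NumPy sketch of a generic multi-gate mixture-of-experts (MMoE) layer. Several expert networks are shared by all tasks, and each task has its own softmax gate that mixes the expert outputs into a task-specific representation. The function `mmoe_forward`, the shapes and the random weights are hypothetical illustrations. This is not m3tl's `MMoE` layer, which may differ in detail (e.g. it may produce per-token outputs for sequence tagging).\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def mmoe_forward(x, expert_weights, gate_weights):\n",
    "    # Hypothetical sketch of a generic MMoE layer, not m3tl's implementation.\n",
    "    # x: (batch, hidden)\n",
    "    # expert_weights: (num_experts, hidden, expert_dim), one ReLU dense layer per expert\n",
    "    # gate_weights: dict mapping task name -> (hidden, num_experts) gate matrix\n",
    "    experts = np.maximum(np.einsum('bh,ehd->ebd', x, expert_weights), 0.0)  # (num_experts, batch, expert_dim)\n",
    "    outputs = {}\n",
    "    for task, w_gate in gate_weights.items():\n",
    "        logits = x @ w_gate  # (batch, num_experts)\n",
    "        logits = logits - logits.max(axis=1, keepdims=True)  # stabilize the softmax\n",
    "        gates = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)\n",
    "        # weighted sum of the expert outputs for this task: (batch, expert_dim)\n",
    "        outputs[task] = np.einsum('be,ebd->bd', gates, experts)\n",
    "    return outputs\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "x = rng.normal(size=(4, 8))                  # 4 examples, hidden size 8\n",
    "expert_weights = rng.normal(size=(3, 8, 5))  # 3 experts, expert_dim 5\n",
    "gate_weights = {task: rng.normal(size=(8, 3)) for task in ['toy_cls', 'toy_seq_tag']}\n",
    "out = mmoe_forward(x, expert_weights, gate_weights)\n",
    "print({task: v.shape for task, v in out.items()})  # {'toy_cls': (4, 5), 'toy_seq_tag': (4, 5)}\n",
    "```\n",
    "\n",
    "Compared with hard parameter sharing, where every task reads the same shared representation, the per-task gates let each task weight the experts differently."
   ],
   "metadata": {}
  },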
  {
   "cell_type": "markdown",
   "source": [
    "## Write More Flexible Preprocessing Function\n",
    "\n",
    "The simplest preprocessing function returns a tuple of two lists, inputs and labels, as shown above. However, inputs can get complicated in multi-modal multi-task learning. In this case, we can store our data in a dictionary with some magic keys:\n",
    "\n",
    "- `\"inputs_\"` and `\"labels_\"` prefixes. We still divide the preprocessing output into inputs and labels. By adding the `\"inputs_\"` or `\"labels_\"` prefix to a dictionary key, the module will handle that key correctly in train, eval and predict.\n",
    "- `\"_modal_type\"` and `\"_modal_info\"` suffixes. Adding these suffixes to an input key (e.g. `inputs_cate_modal_type`) indicates the modal type of that input. If they're not provided, the module will try to infer this information from the data.\n",
    "- `i`. If specified, this key will be used to join problems chained with `&`. It is required if any problems are chained with `&`.\n",
    "\n",
    "Example:\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "source": [
    "from m3tl.predefined_problems.test_data import generate_fake_data\n",
    "gen = generate_fake_data(output_format='gen_dict')\n",
    "pprint.pprint(next(gen))"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "{'inputs_array': array([0.89512351, 0.89110354, 0.70502249, 0.23868364, 0.40018975,\n",
      "       0.52657185, 0.87574078, 0.08114504, 0.93732932, 0.24289513]),\n",
      " 'inputs_cate': 0,\n",
      " 'inputs_cate_modal_info': 1,\n",
      " 'inputs_cate_modal_type': 'category',\n",
      " 'inputs_record_id': 0,\n",
      " 'inputs_text': 'this is a test',\n",
      " 'labels': 'a'}\n"
     ]
    }
   ],
   "metadata": {}
  },
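  {
   "cell_type": "markdown",
   "source": [
    "To make the key conventions concrete, here is a minimal pure-Python sketch of how a record like the one above could be split into inputs, labels and modal annotations. The helper `split_record` is hypothetical and only illustrates the naming rules; m3tl's actual parsing may differ. Since the example record uses a bare `labels` key, the sketch treats any key starting with `labels` as a label, and the `inputs_array` value is abridged.\n",
    "\n",
    "```python\n",
    "def split_record(record):\n",
    "    # Hypothetical helper illustrating the magic-key conventions; not part of m3tl.\n",
    "    inputs, labels, modal_types, modal_infos = {}, {}, {}, {}\n",
    "    for key, value in record.items():\n",
    "        if key.startswith('inputs_'):\n",
    "            name = key[len('inputs_'):]\n",
    "            if name.endswith('_modal_type'):\n",
    "                modal_types[name[:-len('_modal_type')]] = value\n",
    "            elif name.endswith('_modal_info'):\n",
    "                modal_infos[name[:-len('_modal_info')]] = value\n",
    "            else:\n",
    "                inputs[name] = value\n",
    "        elif key.startswith('labels'):\n",
    "            # both a bare 'labels' key and 'labels_xxx' keys count as labels here\n",
    "            labels[key] = value\n",
    "    return inputs, labels, modal_types, modal_infos\n",
    "\n",
    "record = {\n",
    "    'inputs_array': [0.895, 0.891, 0.705],  # abridged\n",
    "    'inputs_cate': 0,\n",
    "    'inputs_cate_modal_info': 1,\n",
    "    'inputs_cate_modal_type': 'category',\n",
    "    'inputs_record_id': 0,\n",
    "    'inputs_text': 'this is a test',\n",
    "    'labels': 'a',\n",
    "}\n",
    "inputs, labels, modal_types, modal_infos = split_record(record)\n",
    "print(sorted(inputs))  # ['array', 'cate', 'record_id', 'text']\n",
    "print(modal_types)     # {'cate': 'category'}\n",
    "```\n"
   ],
   "metadata": {}
  },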
  {
   "cell_type": "markdown",
   "source": [
    "### Local Preprocessing\n",
    "\n",
    "You can return either a list of dictionaries or a generator of dictionaries from your preprocessing function.\n",
    "\n",
    "> Important: If you return a generator of dictionaries, you must call `m3tl.utils.get_or_make_label_encoder` within your preprocessing function!\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "source": [
    "from m3tl.utils import get_or_make_label_encoder\n",
    "from m3tl.special_tokens import TRAIN\n",
    "import inspect\n"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "source": [
    "# collapse_output\n",
    "params.num_cpus = 1\n",
    "@preprocessing_fn\n",
    "def toy_cls(params: Params, mode: str):\n",
    "    # IMPORTANT!\n",
    "    get_or_make_label_encoder(\n",
    "        params=params,\n",
    "        problem=inspect.currentframe().f_code.co_name, # current function name\n",
    "        mode=mode,\n",
    "        label_list=['a', 'b'],\n",
    "        overwrite=True\n",
    "    )\n",
    "    return generate_fake_data(output_format='gen_dict')\n",
    "\n",
    "params.register_problem(problem_name='toy_cls', problem_type='cls', processing_fn=toy_cls)\n",
    "\n",
    "# then you can call the preproc function and take a look at the result\n",
    "pprint.pprint(next(toy_cls(params, TRAIN)))"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "2021-06-24 16:09:15.822 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - record_id: 0\n",
      "2021-06-24 16:09:15.823 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: this is a test\n",
      "2021-06-24 16:09:15.824 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - array: [0.96637881 0.49298023 0.38897724 0.36710049 0.36735467 0.28640803\n",
      " 0.39647259 0.30369951 0.35238779 0.05860911]\n",
      "2021-06-24 16:09:15.825 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - cate: 0\n",
      "2021-06-24 16:09:15.825 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - cate_modal_type: category\n",
      "2021-06-24 16:09:15.825 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - cate_modal_info: 1\n",
      "2021-06-24 16:09:15.826 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - record_id_modal_type: category\n",
      "2021-06-24 16:09:15.826 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text\n",
      "2021-06-24 16:09:15.827 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - array_modal_type: array\n",
      "2021-06-24 16:09:15.827 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - record_id: 0\n",
      "2021-06-24 16:09:15.827 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 2023, 2003, 1037, 3231, 102]\n",
      "2021-06-24 16:09:15.828 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]\n",
      "2021-06-24 16:09:15.828 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]\n",
      "2021-06-24 16:09:15.829 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - toy_cls_label_ids: 0\n",
      "2021-06-24 16:09:15.829 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - array_input_ids: [[0.96637881 0.49298023 0.38897724 0.36710049 0.36735467 0.28640803\n",
      "  0.39647259 0.30369951 0.35238779 0.05860911]]\n",
      "2021-06-24 16:09:15.830 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - array_mask: [1]\n",
      "2021-06-24 16:09:15.830 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - array_segment_ids: [0]\n",
      "2021-06-24 16:09:15.831 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - cate_input_ids: [0]\n",
      "2021-06-24 16:09:15.831 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - cate_mask: [1]\n",
      "2021-06-24 16:09:15.832 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - cate_segment_ids: [0]\n"
     ]
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "{'array_input_ids': array([[0.96637881, 0.49298023, 0.38897724, 0.36710049, 0.36735467,\n",
      "        0.28640803, 0.39647259, 0.30369951, 0.35238779, 0.05860911]]),\n",
      " 'array_mask': [1],\n",
      " 'array_segment_ids': array([0], dtype=int32),\n",
      " 'cate_input_ids': array([0]),\n",
      " 'cate_mask': [1],\n",
      " 'cate_segment_ids': array([0], dtype=int32),\n",
      " 'record_id': 0,\n",
      " 'text_input_ids': [101, 2023, 2003, 1037, 3231, 102],\n",
      " 'text_mask': [1, 1, 1, 1, 1, 1],\n",
      " 'text_segment_ids': [0, 0, 0, 0, 0, 0],\n",
      " 'toy_cls_label_ids': 0}\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### PySpark Preprocessing (Experimental)\n",
    "\n",
    "If your data is too large to process locally, you can also return a PySpark RDD from your preprocessing function.\n",
    "\n",
    "> Important: When using PySpark preprocessing, you must call `m3tl.utils.get_or_make_label_encoder` within your preprocessing function!\n",
    "\n",
    "> Note: `params.pyspark_output_path` must be set if PySpark is enabled.\n",
    "\n",
    "> Note: Local processing and PySpark processing cannot be mixed.\n",
    "\n",
    "If two problems are chained with `&` and share only part of their inputs, their preprocessing functions must return RDDs.\n"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "source": [
    "from m3tl.utils import set_is_pyspark\n",
    "import tempfile"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "source": [
    "set_is_pyspark(True)\n",
    "\n",
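    "@preprocessing_fn\n",
    "def toy_cls(params: Params, mode: str):\n",
    "    # IMPORTANT! the label encoder must be created inside the preprocessing function\n",
    "    get_or_make_label_encoder(\n",
    "        params=params,\n",
    "        problem=inspect.currentframe().f_code.co_name, # current function name\n",
    "        mode=mode,\n",
    "        label_list=['a', 'b'],\n",
    "        overwrite=True\n",
    "    )\n",
    "    return generate_fake_data(output_format='rdd')\n",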
    "\n",
    "params.register_problem(problem_name='toy_cls', problem_type='cls', processing_fn=toy_cls)\n",
    "\n",
    "# set pyspark output path\n",
    "params.pyspark_output_path = tempfile.mkdtemp()\n",
    "\n",
    "# then you can call the preproc function and take a look at the result\n",
    "toy_cls_rdd = toy_cls(params, TRAIN)\n",
    "pprint.pprint(toy_cls_rdd.collect()[0])\n"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "{'array_input_ids': array([[0.00258592, 0.20750642, 0.25051955, 0.85366103, 0.93457556,\n",
      "        0.05154129, 0.48336023, 0.36393742, 0.40964549, 0.77414554]]),\n",
      " 'array_mask': [1],\n",
      " 'array_segment_ids': array([0], dtype=int32),\n",
      " 'cate_input_ids': array([0]),\n",
      " 'cate_mask': [1],\n",
      " 'cate_segment_ids': array([0], dtype=int32),\n",
      " 'record_id': 0,\n",
      " 'text_input_ids': [101, 2023, 2003, 1037, 3231, 102],\n",
      " 'text_mask': [1, 1, 1, 1, 1, 1],\n",
      " 'text_segment_ids': [0, 0, 0, 0, 0, 0],\n",
      " 'toy_cls_label_ids': 0}\n"
     ]
    }
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "### What Happened?\n",
    "\n",
    "The inputs returned by the preprocessing function are tokenized with a transformers tokenizer (configurable, as shown earlier), and the labels are encoded (or tokenized, if the target is text) as scalars or numpy arrays. The encoded inputs and targets are then serialized and written as TFRecords. Note that existing TFRecords will NOT be overwritten when you run the code again, so if you want to change the data in them, you need to manually remove the TFRecord directory. The default directory is `./tmp/{problem_name}`.\n",
    "\n",
    "Once the TFRecords are created, you can check the feature info by looking at the JSON files in the corresponding directory.\n",
    "\n",
    "First, we make sure the TFRecord is created."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "source": [
    "# train_eval_input_fn creates the TFRecords (if they don't exist yet), reads them, and returns a dataset\n",
    "from m3tl.input_fn import train_eval_input_fn\n",
    "\n",
    "dataset = train_eval_input_fn(params)\n"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "Below is the TFRecord directory tree."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "source": [
    "# hide_input\n",
    "import os\n",
    "\n",
    "def list_files(startpath):\n",
    "    for root, dirs, files in os.walk(startpath):\n",
    "        level = root.replace(startpath, '').count(os.sep)\n",
    "        indent = ' ' * 4 * (level)\n",
    "        print('{}{}/'.format(indent, os.path.basename(root)))\n",
    "        subindent = ' ' * 4 * (level + 1)\n",
    "        for f in files:\n",
    "            print('{}{}'.format(subindent, f))\n",
    "\n",
    "list_files(params.tmp_file_dir)"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "tmp/\n",
      "    toy_cls_toy_seq_tag/\n",
      "        eval_00000.tfrecord\n",
      "        train_feature_desc.json\n",
      "        problem_info.txt\n",
      "        eval_feature_desc.json\n",
      "        train_00000.tfrecord\n"
     ]
    }
   ],
   "metadata": {}
  },
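  {
   "cell_type": "markdown",
   "source": [
    "Because existing TFRecords are reused rather than overwritten, changing your preprocessing output means deleting this folder before rerunning. Below is a minimal sketch, assuming (as the tree above suggests) that problems chained with `&` share one folder named by joining the problem names with `_`. The helper `clear_tfrecord_cache` is hypothetical, not part of m3tl; in this notebook you would pass `params.tmp_file_dir` as `tmp_dir`.\n",
    "\n",
    "```python\n",
    "import os\n",
    "import shutil\n",
    "import tempfile\n",
    "\n",
    "def clear_tfrecord_cache(tmp_dir, chained_problem):\n",
    "    # Hypothetical helper, not part of m3tl.\n",
    "    # Assumption: 'toy_cls&toy_seq_tag' is cached under <tmp_dir>/toy_cls_toy_seq_tag\n",
    "    folder = os.path.join(tmp_dir, chained_problem.replace('&', '_'))\n",
    "    if os.path.isdir(folder):\n",
    "        shutil.rmtree(folder)\n",
    "        return True\n",
    "    return False\n",
    "\n",
    "# demo on a throwaway directory that mimics the tree above\n",
    "tmp_dir = tempfile.mkdtemp()\n",
    "os.makedirs(os.path.join(tmp_dir, 'toy_cls_toy_seq_tag'))\n",
    "print(clear_tfrecord_cache(tmp_dir, 'toy_cls&toy_seq_tag'))  # True\n",
    "print(clear_tfrecord_cache(tmp_dir, 'toy_cls&toy_seq_tag'))  # False, already removed\n",
    "```\n"
   ],
   "metadata": {}
  },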
  {
   "cell_type": "markdown",
   "source": [
    "We can take a look at the json file."
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "# problems chained with & share a single TFRecord folder\n",
    "json_path = os.path.join(params.tmp_file_dir, 'toy_cls_toy_seq_tag', 'train_feature_desc.json')\n",
    "with open(json_path, 'r', encoding='utf8') as f:\n",
    "    print(json.dumps(json.load(f), indent=4))"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "{\n",
      "    \"text_input_ids\": \"int64\",\n",
      "    \"text_input_ids_shape_value\": [\n",
      "        null\n",
      "    ],\n",
      "    \"text_input_ids_shape\": \"int64\",\n",
      "    \"text_mask\": \"int64\",\n",
      "    \"text_mask_shape_value\": [\n",
      "        null\n",
      "    ],\n",
      "    \"text_mask_shape\": \"int64\",\n",
      "    \"text_segment_ids\": \"int64\",\n",
      "    \"text_segment_ids_shape_value\": [\n",
      "        null\n",
      "    ],\n",
      "    \"text_segment_ids_shape\": \"int64\",\n",
      "    \"toy_cls_label_ids\": \"int64\",\n",
      "    \"toy_cls_label_ids_shape\": \"int64\",\n",
      "    \"toy_cls_label_ids_shape_value\": [],\n",
      "    \"toy_seq_tag_label_ids\": \"int64\",\n",
      "    \"toy_seq_tag_label_ids_shape_value\": [\n",
      "        null\n",
      "    ],\n",
      "    \"toy_seq_tag_label_ids_shape\": \"int64\"\n",
      "}\n"
     ]
    }
   ],
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.10 64-bit ('base': conda)",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}